{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [2、PyTorch张量的运算API（上）](https://www.bilibili.com/video/BV1wQ4y1q7Bm?spm_id_from=333.788.player.switch&vd_source=cdd897fffb54b70b076681c3c4e4d45d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8978, 0.1323],\n",
       "        [0.7349, 0.8007],\n",
       "        [0.4459, 0.7735]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.rand(3,2)\n",
    "b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.8978, 0.1323],\n",
       "         [0.7349, 0.8007]]),\n",
       " tensor([[0.4459, 0.7735]]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 按照行分割成2个张量\n",
    "torch.chunk(b, chunks=2, dim=0)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.8978],\n",
       "         [0.7349],\n",
       "         [0.4459]]),\n",
       " tensor([[0.1323],\n",
       "         [0.8007],\n",
       "         [0.7735]]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 按照行分割成2个张量\n",
    "torch.chunk(b, chunks=2, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2],\n",
      "        [3, 4]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[1, 1],\n",
       "        [4, 3]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.tensor([[1, 2], [3, 4]])\n",
    "print(t)\n",
    "# dim=1：按照列收集\n",
    "# [0,0]：t[0, index[0,0]] = t[0, 0] = 1\n",
    "# [0,1]：t[0, index[0,1]] = t[0, 0] = 1\n",
    "# [1,0]：t[1, index[1,0]] = t[1, 1] = 4\n",
    "# [1,1]：t[1, index[1,1]] = t[1, 0] = 3\n",
    "torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 2],\n",
       "        [3, 4, 5]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# reshape改变张量的形状（不改变实际内容）\n",
    "a = torch.arange(6)\n",
    "torch.reshape(a, (2,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1],\n",
       "        [2, 3],\n",
       "        [4, 5]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.reshape(a, (3,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 5])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# reshape变成一维\n",
    "torch.reshape(a, (-1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 1.],\n",
      "        [2., 3.]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1.],\n",
       "        [0., 0.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 类似 gather，一个，但是会写入数据\n",
    "src = torch.zeros((2,2), dtype=torch.float32)\n",
    "dest = torch.arange(4, dtype=torch.float32).reshape((2,2))\n",
    "print(dest)\n",
    "dest.scatter(1, index=torch.tensor([[0, 0], [1, 0]]), src=src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0, 1],\n",
       "         [2, 3]]),\n",
       " tensor([[4, 5],\n",
       "         [6, 7]]),\n",
       " tensor([[8, 9]]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 和chunk不同，split指定大小，chunk指定划分多少个\n",
    "a = torch.arange(10).reshape((5,2))\n",
    "a.split(2, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0, 1, 2, 3]),\n",
       " tensor([4, 5, 6]),\n",
       " tensor([7, 8, 9]),\n",
       " tensor([10]),\n",
       " tensor([11]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# split可以指定分成几分\n",
    "a = torch.arange(12).reshape((-1,))\n",
    "a.split([4,3,3,1,1], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[[ 0],\n",
      "           [ 1]],\n",
      "\n",
      "          [[ 2],\n",
      "           [ 3]],\n",
      "\n",
      "          [[ 4],\n",
      "           [ 5]]],\n",
      "\n",
      "\n",
      "         [[[ 6],\n",
      "           [ 7]],\n",
      "\n",
      "          [[ 8],\n",
      "           [ 9]],\n",
      "\n",
      "          [[10],\n",
      "           [11]]]]]) torch.Size([1, 2, 3, 2, 1])\n",
      "tensor([[[ 0,  1],\n",
      "         [ 2,  3],\n",
      "         [ 4,  5]],\n",
      "\n",
      "        [[ 6,  7],\n",
      "         [ 8,  9],\n",
      "         [10, 11]]]) torch.Size([2, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "# squeeze对多余的维度压缩\n",
    "a = torch.arange(12).reshape((1,2,3,2,1))\n",
    "print(a, a.shape)\n",
    "a.squeeze_()\n",
    "print(a, a.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 2, 1])\n"
     ]
    }
   ],
   "source": [
    "# squeeze 指定维度\n",
    "a = torch.arange(12).reshape((1,2,3,2,1))\n",
    "a.squeeze_(dim=0)\n",
    "print(a.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1.],\n",
       "         [1., 1.],\n",
       "         [1., 1.]],\n",
       "\n",
       "        [[2., 2.],\n",
       "         [2., 2.],\n",
       "         [2., 2.]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# stack：沿着新维度拼接起来\n",
    "a = torch.ones((3,2))\n",
    "b = torch.ones((3,2)) * 2\n",
    "torch.stack([a,b]) # 变成3维 (2,3,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1.],\n",
       "         [2., 2.]],\n",
       "\n",
       "        [[1., 1.],\n",
       "         [2., 2.]],\n",
       "\n",
       "        [[1., 1.],\n",
       "         [2., 2.]]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack([a,b], dim=1) # 变成3维 (3,2,2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
